{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Counterfactuals guided by prototypes on California housing dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook goes through an example of [prototypical counterfactuals](../methods/CFProto.ipynb) using [k-d trees](https://en.wikipedia.org/wiki/K-d_tree) to build the prototypes. Please check out [this notebook](./cfproto_mnist.ipynb) for a more in-depth application of the method on MNIST using (auto-)encoders and trust scores.\n",
    "\n",
    "In this example, we will train a simple neural net to predict whether house prices in California districts are above the median value or not. We can then find a counterfactual to see which variables need to be changed to increase or decrease a house price above or below the median value."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "Note\n",
    "\n",
    "To enable support for CounterfactualProto, you may need to run\n",
    "    \n",
    "```bash\n",
    "pip install alibi[tensorflow]\n",
    "```\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TF version:  2.7.4\n",
      "Eager execution enabled:  False\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import os\n",
    "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\"\n",
    "\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel(40) # suppress deprecation messages\n",
    "tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs \n",
    "from tensorflow.keras.layers import Dense, Input\n",
    "from tensorflow.keras.models import Model, load_model\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from alibi.explainers import CounterfactualProto\n",
    "\n",
    "print('TF version: ', tf.__version__)\n",
    "print('Eager execution enabled: ', tf.executing_eagerly()) # False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and prepare California housing dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "california = fetch_california_housing(as_frame=True)\n",
    "X = california.data.to_numpy()\n",
    "target = california.target.to_numpy()\n",
    "feature_names = california.feature_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MedInc</th>\n",
       "      <th>HouseAge</th>\n",
       "      <th>AveRooms</th>\n",
       "      <th>AveBedrms</th>\n",
       "      <th>Population</th>\n",
       "      <th>AveOccup</th>\n",
       "      <th>Latitude</th>\n",
       "      <th>Longitude</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>8.3252</td>\n",
       "      <td>41.0</td>\n",
       "      <td>6.984127</td>\n",
       "      <td>1.023810</td>\n",
       "      <td>322.0</td>\n",
       "      <td>2.555556</td>\n",
       "      <td>37.88</td>\n",
       "      <td>-122.23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8.3014</td>\n",
       "      <td>21.0</td>\n",
       "      <td>6.238137</td>\n",
       "      <td>0.971880</td>\n",
       "      <td>2401.0</td>\n",
       "      <td>2.109842</td>\n",
       "      <td>37.86</td>\n",
       "      <td>-122.22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.2574</td>\n",
       "      <td>52.0</td>\n",
       "      <td>8.288136</td>\n",
       "      <td>1.073446</td>\n",
       "      <td>496.0</td>\n",
       "      <td>2.802260</td>\n",
       "      <td>37.85</td>\n",
       "      <td>-122.24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5.6431</td>\n",
       "      <td>52.0</td>\n",
       "      <td>5.817352</td>\n",
       "      <td>1.073059</td>\n",
       "      <td>558.0</td>\n",
       "      <td>2.547945</td>\n",
       "      <td>37.85</td>\n",
       "      <td>-122.25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.8462</td>\n",
       "      <td>52.0</td>\n",
       "      <td>6.281853</td>\n",
       "      <td>1.081081</td>\n",
       "      <td>565.0</td>\n",
       "      <td>2.181467</td>\n",
       "      <td>37.85</td>\n",
       "      <td>-122.25</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   MedInc  HouseAge  AveRooms  AveBedrms  Population  AveOccup  Latitude  \\\n",
       "0  8.3252      41.0  6.984127   1.023810       322.0  2.555556     37.88   \n",
       "1  8.3014      21.0  6.238137   0.971880      2401.0  2.109842     37.86   \n",
       "2  7.2574      52.0  8.288136   1.073446       496.0  2.802260     37.85   \n",
       "3  5.6431      52.0  5.817352   1.073059       558.0  2.547945     37.85   \n",
       "4  3.8462      52.0  6.281853   1.081081       565.0  2.181467     37.85   \n",
       "\n",
       "   Longitude  \n",
       "0    -122.23  \n",
       "1    -122.22  \n",
       "2    -122.24  \n",
       "3    -122.25  \n",
       "4    -122.25  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "california.data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each row represents a whole census group. Explanation of features:\n",
    "\n",
    " - `MedInc` - median income in block group\n",
    " - `HouseAge` - median house age in block group\n",
    " - `AveRooms` - average number of rooms per household\n",
    " - `AveBedrms` - average number of bedrooms per household\n",
    " - `Population` - block group population\n",
    " - `AveOccup` - average number of household members\n",
    " - `Latitude` - block group latitude\n",
    " - `Longitude` - block group longitude\n",
    " \n",
    " For more details on the dataset, refer to the [scikit-learn documentation](https://scikit-learn.org/stable/datasets/real_world.html#california-housing-dataset)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Transform into classification task: target becomes whether house price is above the overall median or not"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.zeros((target.shape[0],))\n",
    "y[np.where(target > np.median(target))[0]] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Standardize data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "mu = X.mean(axis=0)\n",
    "sigma = X.std(axis=0)\n",
    "X = (X - mu) / sigma"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define train and test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "y_train = to_categorical(y_train)\n",
    "y_test = to_categorical(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nn_model():\n",
    "    x_in = Input(shape=(8,))\n",
    "    x = Dense(40, activation='relu')(x_in)\n",
    "    x = Dense(40, activation='relu')(x)\n",
    "    x_out = Dense(2, activation='softmax')(x)\n",
    "    nn = Model(inputs=x_in, outputs=x_out)\n",
    "    nn.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])\n",
    "    return nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn = nn_model()\n",
    "nn.summary()\n",
    "nn.fit(X_train, y_train, batch_size=64, epochs=500, verbose=0)\n",
    "nn.save('nn_california.h5', save_format='h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy:  0.87863374\n"
     ]
    }
   ],
   "source": [
    "nn = load_model('nn_california.h5')\n",
    "score = nn.evaluate(X_test, y_test, verbose=0)\n",
    "print('Test accuracy: ', score[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate counterfactual guided by the nearest class prototype"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Original instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = X_test[1].reshape((1,) + X_test[1].shape)\n",
    "shape = X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run counterfactual:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# define model\n",
    "nn = load_model('nn_california.h5')\n",
    "\n",
    "# initialize and fit the explainer\n",
    "cf = CounterfactualProto(nn, shape, use_kdtree=True, theta=10., max_iterations=1000,\n",
    "                         feature_range=(X_train.min(axis=0), X_train.max(axis=0)), \n",
    "                         c_init=1., c_steps=10)\n",
    "\n",
    "cf.fit(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate a counterfactual\n",
    "explanation = cf.explain(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction flipped from 0 (value below the median) to 1 (above the median):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original prediction: 0\n",
      "Counterfactual prediction: 1\n"
     ]
    }
   ],
   "source": [
    "print(f'Original prediction: {explanation.orig_class}')\n",
    "print(f'Counterfactual prediction: {explanation.cf[\"class\"]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the counterfactual. To make the results more interpretable, we will first undo the pre-processing step and then check where the counterfactual differs from the original instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AveOccup: -0.9049749915631999\n",
      "Latitude: -0.31885583625280134\n"
     ]
    }
   ],
   "source": [
    "orig = X * sigma + mu\n",
    "counterfactual = explanation.cf['X'] * sigma + mu\n",
    "delta = counterfactual - orig\n",
    "for i, f in enumerate(feature_names):\n",
    "    if np.abs(delta[0][i]) > 1e-4:\n",
    "        print(f'{f}: {delta[0][i]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So in order for the model to consider the census group as having above median house prices, the average occupancy would have to be lower by almost a whole household member, and the location of the census group would need to shift slightly South."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparing the original instance and the counterfactual side-by-side:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MedInc</th>\n",
       "      <th>HouseAge</th>\n",
       "      <th>AveRooms</th>\n",
       "      <th>AveBedrms</th>\n",
       "      <th>Population</th>\n",
       "      <th>AveOccup</th>\n",
       "      <th>Latitude</th>\n",
       "      <th>Longitude</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.5313</td>\n",
       "      <td>30.0</td>\n",
       "      <td>5.039384</td>\n",
       "      <td>1.193493</td>\n",
       "      <td>1565.0</td>\n",
       "      <td>2.679795</td>\n",
       "      <td>35.14</td>\n",
       "      <td>-119.46</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   MedInc  HouseAge  AveRooms  AveBedrms  Population  AveOccup  Latitude  \\\n",
       "0  2.5313      30.0  5.039384   1.193493      1565.0  2.679795     35.14   \n",
       "\n",
       "   Longitude  \n",
       "0    -119.46  "
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(orig, columns=feature_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MedInc</th>\n",
       "      <th>HouseAge</th>\n",
       "      <th>AveRooms</th>\n",
       "      <th>AveBedrms</th>\n",
       "      <th>Population</th>\n",
       "      <th>AveOccup</th>\n",
       "      <th>Latitude</th>\n",
       "      <th>Longitude</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.5313</td>\n",
       "      <td>30.0</td>\n",
       "      <td>5.039384</td>\n",
       "      <td>1.193493</td>\n",
       "      <td>1565.000004</td>\n",
       "      <td>1.77482</td>\n",
       "      <td>34.821144</td>\n",
       "      <td>-119.46</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   MedInc  HouseAge  AveRooms  AveBedrms   Population  AveOccup   Latitude  \\\n",
       "0  2.5313      30.0  5.039384   1.193493  1565.000004   1.77482  34.821144   \n",
       "\n",
       "   Longitude  \n",
       "0    -119.46  "
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(counterfactual, columns=feature_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean up:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.remove('nn_california.h5')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
